import functools

from .ct import Ct
from config import SEG_CACHE_PATH
from utils.disk import getCache


# In order to get decent performance out of LunaDataset, we'll need to invest in
# some on-disk caching. This will allow us to avoid having to read an entire CT
# scan from disk for every sample. Doing so would be prohibitively slow! Make
# sure you're paying attention to bottlenecks in your project and doing what you
# can to optimize them once they start slowing you down. We're kind of jumping
# the gun here since we haven't demonstrated that we need caching here. Without
# caching, the LunaDataset is easily 50 times slower!

# Module-level on-disk cache (diskcache) backing the @raw_cache.memoize
# decorators below; location is configured via SEG_CACHE_PATH.
raw_cache = getCache(SEG_CACHE_PATH)


# We use a few different caching methods here. First, we're caching the getCt
# return value in memory so that we can repeatedly ask for the same Ct instance
# without having to reload all of the data from disk. That's a huge speed
# increase in the case of repeated requests, but we're only keeping one CT in
# memory, so cache misses will be frequent if we're not careful about access
# order.
@functools.lru_cache(1, typed=True)
def getCt(series_uid):
    """Build (or return the in-memory cached copy of) the Ct for series_uid.

    The size-1 LRU cache keeps exactly one Ct instance alive, so repeated
    requests for the same series avoid reloading the scan from disk, while
    memory use stays bounded to a single CT volume.
    """
    return Ct(series_uid)


# After our cache is populated, getCt won't ever be called. These values are
# cached to disk using the Python library diskcache. It's much, much faster to
# read in 2**15 float32 values from disk than it is to read in 2**25 int16
# values, convert to float32, and then select a 2**15 subset. From the second
# pass through the data forward, I/O times for input should drop to insignificance.
@raw_cache.memoize(typed=True)
def getCtRawCandidate(series_uid, center_xyz, width_irc):
    """Extract one candidate chunk from a CT scan, with disk-level caching.

    Loads (via the in-memory getCt cache) the scan for series_uid, crops the
    region of width_irc voxels around the physical point center_xyz, clamps
    the HU values, and returns (ct_chunk, pos_chunk, center_irc). Results are
    memoized to disk by raw_cache, so subsequent calls with identical
    arguments skip the CT load entirely.
    """
    ct = getCt(series_uid)
    chunk, mask_chunk, center_irc = ct.getRawCandidate(center_xyz, width_irc)
    # Clamp in place to the standard HU window; values beyond +/-1000
    # (air / dense bone and beyond) carry no useful signal here.
    chunk.clip(-1000, 1000, chunk)
    return chunk, mask_chunk, center_irc


# For validation, we'll need to produce one sample per slice of CT that has an
# entry in the positive mask, for each validation CT we have. Since different
# CT scans can have different slice counts, we're going to introduce a new
# function that caches the size of each CT scan and its positive mask to disk.
# We need this to be able to quickly construct the full size of a validation
# set without having to load each CT at Dataset initialization.
#
# Populating this data will also take place during the prepcache.py script,
# which we must run once before we start any model training.
@raw_cache.memoize(typed=True)
def getCtSampleSize(series_uid):
    """Return (slice_count, positive_indexes) for one CT series, disk-cached.

    Loads the Ct directly (bypassing the getCt LRU, since only metadata is
    needed) and reports how many axial slices the scan has along with the
    indexes of slices that contain positive mask entries. Caching this pair
    to disk lets the validation Dataset size itself without loading every CT
    at initialization time.
    """
    ct = Ct(series_uid)
    slice_count = int(ct.hu_a.shape[0])
    return slice_count, ct.positive_indexes
